/***************************************************************************
 *   Copyright (C) 2022 by David Register                                  *
 *   Copyright (C) 2022 Alec Leamas                                        *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, see <https://www.gnu.org/licenses/>. *
 **************************************************************************/

/**
 * \file
 *
 * Implement mDNS Query, and friends.
 */

#include <algorithm>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#if defined(_WIN32) && !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS 1
#endif

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <errno.h>
#include <signal.h>

#ifdef _WIN32
#include <winsock2.h>
#include <iphlpapi.h>
#define sleep(x) Sleep(x * 1000)
#else
#include <netdb.h>
#include <ifaddrs.h>
#include <net/if.h>
#endif

#include <wx/datetime.h>
#include <wx/log.h>

#ifdef __ANDROID__
#include "androidUTIL.h"
#endif

#include "model/cmdline.h"
#include "mdns_util.h"
#include "model/mdns_cache.h"
#include "model/mdns_query.h"

// Static data structs

// Signal K servers found by mDNS: cleared by FindAllSignalKServers() and
// appended to / updated by sk_query_callback() on the query thread.
// NOTE(review): has external linkage, presumably declared extern in a header
// and read elsewhere; no locking is visible here — confirm readers do not
// access it while a query thread is still running.
std::vector<ocpn_DNS_record_t> g_sk_servers;

// Scratch buffers used by the record callbacks and helpers in this file.
// NOTE(review): shared by all query threads; two overlapping queries could
// race on them.
static char addrbuffer[64];    // Responder address text (ip_address_to_string).
static char entrybuffer[256];  // Record owner name (mdns_string_extract).
static char namebuffer[256];   // PTR/SRV names and A-record address text.
static char sendbuffer[1024];  // Not referenced in the code shown here.
static mdns_record_txt_t txtbuffer[128];  // Not referenced in the code shown here.

// First non-loopback IPv4 address found by get_local_ipv4_addresses()
// (Windows and non-Android POSIX builds). The IPv6 counterpart is not written
// by the code shown here.
static struct sockaddr_in service_address_ipv4;
static struct sockaddr_in6 service_address_ipv6;

// has_ipv4 is set to 1 by get_local_ipv4_addresses() when it finds a usable
// IPv4 address; has_ipv6 is not written by the code shown here.
static int has_ipv4;
static int has_ipv6;

/**
 * printf-style debug output for the mDNS code, written to stdout.
 *
 * Output is produced only when the OCPN_MDNS_DEBUG environment variable is
 * set, or when the wxWidgets log level is at least wxLOG_Debug.
 *
 * wxLog::GetLogLevel() is a static member, so it is called directly instead
 * of through wxLog::GetActiveTarget(): the active target may be NULL, and
 * looking it up can create a log target on demand. That matters here because
 * this function is also called from the detached mDNS query threads.
 *
 * @param fmt printf-style format string, followed by its arguments.
 */
static void log_printf(const char* fmt, ...) {
  if (!getenv("OCPN_MDNS_DEBUG") && wxLog::GetLogLevel() < wxLOG_Debug) return;

  va_list ap;
  va_start(ap, fmt);
  vprintf(fmt, ap);
  va_end(ap);
}

/**
 * mDNS record callback for the "opencpn-object-control-service" query.
 *
 * Only PTR records received from IPv4 addresses are used. For each one:
 *  - the PTR target is taken as the service instance name, and the part in
 *    front of "opencpn-object" (less one separator character) as host name;
 *  - any ":port" suffix is stripped from the responder's address text;
 *  - the instance is added to MdnsCache with port "8001" when the host name
 *    contains "Portable", otherwise "8000".
 * Each accepted record is also logged through log_printf().
 *
 * Parameters follow mdns_record_callback_fn; sock, query_id, name_length and
 * user_data are unused. Always returns 0.
 */
static int ocpn_query_callback(int sock, const struct sockaddr* from,
                               size_t addrlen, mdns_entry_type_t entry,
                               uint16_t query_id, uint16_t rtype,
                               uint16_t rclass, uint32_t ttl, const void* data,
                               size_t size, size_t name_offset,
                               size_t name_length, size_t record_offset,
                               size_t record_length, void* user_data) {
  (void)sock;
  (void)query_id;
  (void)name_length;
  (void)user_data;

  // Only ipv4 PTR responses are to be used.
  if (rtype != MDNS_RECORDTYPE_PTR || from->sa_family != AF_INET) return 0;

  mdns_string_t fromaddrstr =
      ip_address_to_string(addrbuffer, sizeof(addrbuffer), from, addrlen);
  const char* entrytype =
      (entry == MDNS_ENTRYTYPE_ANSWER)
          ? "answer"
          : ((entry == MDNS_ENTRYTYPE_AUTHORITY) ? "authority" : "additional");
  mdns_string_t entrystr = mdns_string_extract(
      data, size, &name_offset, entrybuffer, sizeof(entrybuffer));
  mdns_string_t namestr =
      mdns_record_parse_ptr(data, size, record_offset, record_length,
                            namebuffer, sizeof(namebuffer));
  log_printf("%.*s : %s %.*s PTR %.*s rclass 0x%x ttl %u length %d\n",
             MDNS_STRING_FORMAT(fromaddrstr), entrytype,
             MDNS_STRING_FORMAT(entrystr), MDNS_STRING_FORMAT(namestr),
             rclass, ttl, static_cast<int>(record_length));

  std::string srv(namestr.str, namestr.length);
  size_t rh = srv.find("opencpn-object");
  // Drop the separator in front of the marker; keep the whole name if absent.
  if (rh != std::string::npos && rh > 1) rh--;
  std::string hostname = srv.substr(0, rh);

  // Responder address text; keep only the part before any ':'.
  std::string from_str(fromaddrstr.str, fromaddrstr.length);
  std::string ip = from_str.substr(0, from_str.find(':'));

  // Is the destination a portable?  Detect by string inspection.
  std::string port =
      hostname.find("Portable") == std::string::npos ? "8000" : "8001";
  MdnsCache::GetInstance().Add(srv, hostname, ip, port);

  return 0;
}

static int sk_query_callback(int sock, const struct sockaddr* from,
                             size_t addrlen, mdns_entry_type_t entry,
                             uint16_t query_id, uint16_t rtype, uint16_t rclass,
                             uint32_t ttl, const void* data, size_t size,
                             size_t name_offset, size_t name_length,
                             size_t record_offset, size_t record_length,
                             void* user_data) {
  (void)sizeof(sock);
  (void)sizeof(query_id);
  (void)sizeof(name_length);
  (void)sizeof(user_data);
  mdns_string_t fromaddrstr =
      ip_address_to_string(addrbuffer, sizeof(addrbuffer), from, addrlen);
  const char* entrytype =
      (entry == MDNS_ENTRYTYPE_ANSWER)
          ? "answer"
          : ((entry == MDNS_ENTRYTYPE_AUTHORITY) ? "authority" : "additional");
  mdns_string_t entrystr = mdns_string_extract(
      data, size, &name_offset, entrybuffer, sizeof(entrybuffer));
  bool is_ipv4 =
      from->sa_family == AF_INET;  // Only ipv4 responses are to be used.

  if ((rtype == MDNS_RECORDTYPE_PTR) && is_ipv4) {
    mdns_string_t namestr =
        mdns_record_parse_ptr(data, size, record_offset, record_length,
                              namebuffer, sizeof(namebuffer));
    std::string srv(namestr.str, namestr.length);
    size_t rh = srv.find("_signalk-ws");
    if (rh > 1) {
      rh--;
    }
    std::string hostname = srv.substr(0, rh);
    // Remove non-printable characters as seen in names returned by macOS
    hostname.erase(remove_if(hostname.begin(), hostname.end(),
                             [](char c) { return (c < 0); }),
                   hostname.end());
    bool found = false;
    for (const auto& sks : g_sk_servers) {
      if (sks.hostname == hostname) {
        found = true;
        break;
      }
    }
    if (!found) {
      ocpn_DNS_record_t sk_server;
      sk_server.service_instance = srv;
      sk_server.hostname = hostname;
      g_sk_servers.push_back(sk_server);
    }
  } else if ((rtype == MDNS_RECORDTYPE_SRV) && is_ipv4) {
    mdns_record_srv_t srv =
        mdns_record_parse_srv(data, size, record_offset, record_length,
                              namebuffer, sizeof(namebuffer));
    g_sk_servers.back().port = std::to_string(srv.port);
  } else if ((rtype == MDNS_RECORDTYPE_A) && is_ipv4) {
    sockaddr_in addr;
    mdns_record_parse_a(data, size, record_offset, record_length, &addr);
    mdns_string_t addrstr = ipv4_address_to_string(
        namebuffer, sizeof(namebuffer), &addr, sizeof(addr));
    g_sk_servers.back().ip = addrstr.str;
  } else {
    // log_printf("SOMETING ELSE\n");
  }
  return 0;
}

/**
 * Send mDNS queries on every client socket and dispatch the replies.
 *
 * Opens the sockets returned by open_client_sockets() and sends all @p count
 * queries on each one. It then repeatedly select()s on the sockets and
 * passes every received record to @p callback_function. Reading stops as
 * soon as one select() call sees no reply within @p timeout_secs, so the
 * total run time can exceed timeout_secs while replies keep arriving.
 *
 * Query entries whose type is not SRV, A or AAAA are rewritten to PTR.
 *
 * @param query             Array of @p count queries; types may be modified.
 * @param count             Number of entries in @p query.
 * @param timeout_secs      Idle timeout, in seconds, for each select() call.
 * @param callback_function Called for every record received.
 * @return 0 once the queries have been sent and replies read; -1 if the I/O
 *         buffer could not be allocated or no client socket could be opened.
 */
int send_mdns_query(mdns_query_t* query, size_t count, size_t timeout_secs,
                    mdns_record_callback_fn callback_function) {
  const size_t capacity = 2048;
  // Owned by unique_ptr so it is released on every return path.
  std::unique_ptr<void, decltype(&free)> buffer(malloc(capacity), &free);
  if (!buffer) {
    log_printf("Failed to allocate mDNS buffer\n");
    return -1;
  }

  int sockets[32];
  int query_id[32];
  const int num_sockets =
      open_client_sockets(sockets, sizeof(sockets) / sizeof(sockets[0]), 0);
  if (num_sockets <= 0) {
    log_printf("Failed to open any client sockets\n");
    return -1;
  }
  // Plural suffix for the socket count in log messages.
  const char* plural = num_sockets > 1 ? "s" : "";
  log_printf("Opened %d socket%s for mDNS query\n", num_sockets, plural);

  log_printf("Sending mDNS query");
  for (size_t iq = 0; iq < count; ++iq) {
    const char* record_name = "PTR";
    if (query[iq].type == MDNS_RECORDTYPE_SRV)
      record_name = "SRV";
    else if (query[iq].type == MDNS_RECORDTYPE_A)
      record_name = "A";
    else if (query[iq].type == MDNS_RECORDTYPE_AAAA)
      record_name = "AAAA";
    else
      query[iq].type = MDNS_RECORDTYPE_PTR;  // Anything else is sent as PTR.
    log_printf(" : %s %s", query[iq].name, record_name);
  }
  log_printf("\n");
  for (int isock = 0; isock < num_sockets; ++isock) {
    query_id[isock] = mdns_multiquery_send(sockets[isock], query, count,
                                           buffer.get(), capacity, 0);
    if (query_id[isock] < 0)
      log_printf("Failed to send mDNS query: %s\n", strerror(errno));
  }

  // Simple implementation: keep reading until one select() call times out
  // without any reply.
  log_printf("Reading mDNS query replies\n");
  int records = 0;
  int res;
  do {
    struct timeval timeout;
    timeout.tv_sec = static_cast<decltype(timeout.tv_sec)>(timeout_secs);
    timeout.tv_usec = 0;

    // select() modifies the set, so it is rebuilt on every iteration.
    int nfds = 0;
    fd_set readfs;
    FD_ZERO(&readfs);
    for (int isock = 0; isock < num_sockets; ++isock) {
      if (sockets[isock] >= nfds) nfds = sockets[isock] + 1;
      FD_SET(sockets[isock], &readfs);
    }

    res = select(nfds, &readfs, 0, 0, &timeout);
    if (res > 0) {
      for (int isock = 0; isock < num_sockets; ++isock) {
        if (!FD_ISSET(sockets[isock], &readfs)) continue;
        int rec = mdns_query_recv(sockets[isock], buffer.get(), capacity,
                                  callback_function, nullptr, query_id[isock]);
        if (rec > 0) records += rec;
      }
    }
  } while (res > 0);

  log_printf("Read %d records\n", records);

  for (int isock = 0; isock < num_sockets; ++isock)
    mdns_socket_close(sockets[isock]);
  log_printf("Closed socket%s\n", plural);

  return 0;
}

// Static query definition, overwritten by both FindAllOCPNServers() and
// FindAllSignalKServers() before they start a query thread.
// Be careful with thread sync if multiple queries are used simultaneously.
mdns_query_t s_query;

void FindAllOCPNServers(size_t timeout_secs) {
  s_query.name = "opencpn-object-control-service";
  s_query.type = MDNS_RECORDTYPE_PTR;
  s_query.length = strlen(s_query.name);

  std::thread{send_mdns_query, &s_query, 1, timeout_secs, ocpn_query_callback}
      .detach();
  // send_mdns_query(&query, 1, timeout_secs);
}

void FindAllSignalKServers(size_t timeout_secs) {
  g_sk_servers.clear();
  s_query.name = "_signalk-ws._tcp.local.";
  s_query.type = MDNS_RECORDTYPE_PTR;
  s_query.length = strlen(s_query.name);

  std::thread{send_mdns_query, &s_query, 1, timeout_secs, sk_query_callback}
      .detach();
}

std::vector<std::string> get_local_ipv4_addresses() {
  std::vector<std::string> ret_vec;

#ifdef __ANDROID__
  wxString ipa = androidGetIpV4Address();
  ret_vec.push_back(ipa.ToStdString());
#endif

  // When sending, each socket can only send to one network interface
  // Thus we need to open one socket for each interface and address family
  int num_sockets = 0;

#ifdef _WIN32

  IP_ADAPTER_ADDRESSES* adapter_address = 0;
  ULONG address_size = 8000;
  unsigned int ret;
  unsigned int num_retries = 4;
  do {
    adapter_address = (IP_ADAPTER_ADDRESSES*)malloc(address_size);
    ret = GetAdaptersAddresses(AF_UNSPEC,
                               GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_ANYCAST,
                               0, adapter_address, &address_size);
    if (ret == ERROR_BUFFER_OVERFLOW) {
      free(adapter_address);
      adapter_address = 0;
      address_size *= 2;
    } else {
      break;
    }
  } while (num_retries-- > 0);

  if (!adapter_address || (ret != NO_ERROR)) {
    free(adapter_address);
    log_printf("Failed to get network adapter addresses\n");
    return ret_vec;
  }

  int first_ipv4 = 1;
  int first_ipv6 = 1;
  for (PIP_ADAPTER_ADDRESSES adapter = adapter_address; adapter;
       adapter = adapter->Next) {
    if (adapter->TunnelType == TUNNEL_TYPE_TEREDO) continue;
    if (adapter->OperStatus != IfOperStatusUp) continue;

    for (IP_ADAPTER_UNICAST_ADDRESS* unicast = adapter->FirstUnicastAddress;
         unicast; unicast = unicast->Next) {
      if (unicast->Address.lpSockaddr->sa_family == AF_INET) {
        struct sockaddr_in* saddr =
            (struct sockaddr_in*)unicast->Address.lpSockaddr;
        if ((saddr->sin_addr.S_un.S_un_b.s_b1 != 127) ||
            (saddr->sin_addr.S_un.S_un_b.s_b2 != 0) ||
            (saddr->sin_addr.S_un.S_un_b.s_b3 != 0) ||
            (saddr->sin_addr.S_un.S_un_b.s_b4 != 1)) {
          int log_addr = 0;
          if (first_ipv4) {
            service_address_ipv4 = *saddr;
            first_ipv4 = 0;
            log_addr = 1;
          }
          has_ipv4 = 1;

          char buffer[128];
          mdns_string_t addr = ipv4_address_to_string(
              buffer, sizeof(buffer), saddr, sizeof(struct sockaddr_in));
          std::string addr_string(addr.str, addr.length);
          ret_vec.push_back(addr_string);
        }
      }
    }
  }
  free(adapter_address);

#endif

#if !defined(_WIN32) && !defined(__ANDROID__)

  struct ifaddrs* ifaddr = 0;
  struct ifaddrs* ifa = 0;

  if (getifaddrs(&ifaddr) < 0)
    log_printf("Unable to get interface addresses\n");

  int first_ipv4 = 1;
  int first_ipv6 = 1;
  for (ifa = ifaddr; ifa; ifa = ifa->ifa_next) {
    if (!ifa->ifa_addr) continue;
    if (!(ifa->ifa_flags & IFF_UP) || !(ifa->ifa_flags & IFF_MULTICAST))
      continue;
    if ((ifa->ifa_flags & IFF_LOOPBACK) || (ifa->ifa_flags & IFF_POINTOPOINT))
      continue;

    if (ifa->ifa_addr->sa_family == AF_INET) {
      struct sockaddr_in* saddr = (struct sockaddr_in*)ifa->ifa_addr;
      if (saddr->sin_addr.s_addr != htonl(INADDR_LOOPBACK)) {
        int log_addr = 0;
        if (first_ipv4) {
          service_address_ipv4 = *saddr;
          first_ipv4 = 0;
          log_addr = 1;
        }
        has_ipv4 = 1;

        char buffer[128];
        mdns_string_t addr = ipv4_address_to_string(
            buffer, sizeof(buffer), saddr, sizeof(struct sockaddr_in));
        std::string addr_string(addr.str, addr.length);
        ret_vec.push_back(addr_string);
      }
    }
  }

  freeifaddrs(ifaddr);

#endif

  return ret_vec;
}
